// Copyright © 2023 Cisco Systems, Inc. and its affiliates.
// All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package yara

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"

	ruleserverconfig "github.com/openclarity/yara-rule-server/pkg/config"
	"github.com/openclarity/yara-rule-server/pkg/rules"
	"github.com/sirupsen/logrus"

	"github.com/openclarity/openclarity/core/log"
	"github.com/openclarity/openclarity/scanner/common"
	"github.com/openclarity/openclarity/scanner/families"
	"github.com/openclarity/openclarity/scanner/families/malware/types"
	"github.com/openclarity/openclarity/scanner/families/malware/yara/config"
	yarautils "github.com/openclarity/openclarity/scanner/families/malware/yara/utils"
	familiesutils "github.com/openclarity/openclarity/scanner/families/utils"
	"github.com/openclarity/openclarity/scanner/utils"
)

// ScannerName is the name of this scanner. Within this file it is also used as
// the name of the cache subdirectory created by createCacheDir.
const ScannerName = "yara"

// Scanner runs the yara binary against a filesystem using a compiled rule file
// whose path is resolved once, when the scanner is created by New.
type Scanner struct {
	// config is the yara scanner configuration; Scan reads the yara binary
	// path, the compiled rule URL and the directories to scan from it.
	config           config.Config
	// compiledRuleFile is the path of the compiled rule file that Scan passes
	// to yara with the -C flag.
	compiledRuleFile string
}

// New builds a yara Scanner from the Yara section of the scanners config.
//
// The compiled rule file is resolved up front through getCompiledRuleFilePath,
// so the returned scanner already knows which rule file it will hand to yara.
// An error is returned when that resolution fails.
func New(ctx context.Context, _ string, config types.ScannersConfig) (families.Scanner[*types.ScannerResult], error) {
	yaraConfig := config.Yara

	// Resolve (and, if needed, download or compile) the yara rule file.
	rulePath, err := getCompiledRuleFilePath(yaraConfig, log.GetLoggerFromContextOrDefault(ctx))
	if err != nil {
		return nil, fmt.Errorf("failed to get compiled rule file path: %w", err)
	}

	scanner := &Scanner{
		compiledRuleFile: rulePath,
		config:           yaraConfig,
	}

	return scanner, nil
}

// Scan runs the yara binary against the filesystem obtained from userInput and
// returns the malware entries parsed from its output.
//
// Only ROOTFS, DIR, IMAGE, DOCKERARCHIVE, OCIARCHIVE and OCIDIR inputs are
// accepted; the input is converted to a filesystem path before scanning. When
// the configuration lists DirectoriesToScan, yara is run once per listed
// directory (joined onto the filesystem root); otherwise it is run once on the
// root itself. Every run uses the same arguments: the compiled rule file, the
// -r, -w and -m flags, and -p set to the number of CPUs.
//
// An error is returned when the input type is unsupported, the yara binary
// cannot be found, the input cannot be converted, a yara run fails, or the
// share of scan errors or of unparsable output lines reaches the threshold
// evaluated by yarautils.IsErrorThresholdReached.
//
// nolint: gocognit,cyclop
func (s *Scanner) Scan(ctx context.Context, inputType common.InputType, userInput string) (*types.ScannerResult, error) {
	if !inputType.IsOneOf(common.ROOTFS, common.DIR, common.IMAGE, common.DOCKERARCHIVE, common.OCIARCHIVE, common.OCIDIR) {
		return nil, fmt.Errorf("unsupported input type=%v", inputType)
	}

	logger := log.GetLoggerFromContextOrDefault(ctx)

	// Report the exact name that was looked up rather than the raw config
	// field, which may differ from what the getter returns.
	yaraBinaryName := s.config.GetYaraBinaryPath()
	yaraBinaryPath, err := exec.LookPath(yaraBinaryName)
	if err != nil {
		return nil, fmt.Errorf("failed to lookup executable %s: %w", yaraBinaryName, err)
	}
	logger.Debugf("found yara binary at: %s", yaraBinaryPath)

	logger.Debugf("Yara rules URL: %s", s.config.CompiledRuleURL)
	logger.Debugf("Yara rules file path: %s", s.compiledRuleFile)

	fsPath, cleanup, err := familiesutils.ConvertInputToFilesystem(ctx, inputType, userInput)
	if err != nil {
		return nil, fmt.Errorf("failed to convert input to filesystem: %w", err)
	}
	defer cleanup()

	var detectedMalware []types.Malware
	var parseErrSamples, scanErrSamples []error
	var outputLines, parseErrorCount, scanErrCount uint

	// parserFunc handles one line of yara output: it counts the line, records
	// parse failures (keeping only a bounded sample of the errors) and collects
	// every successfully parsed malware entry.
	parserFunc := func(line string) {
		outputLines++
		malware, err := yarautils.ParseYaraScanOutput(line)
		if err != nil {
			// errors.As is required to match by error type; errors.Is against
			// a nil typed pointer would not recognise an InvalidLineError.
			var invalidLineErr *yarautils.InvalidLineError
			if errors.As(err, &invalidLineErr) {
				logger.Debugf("Omitting invalid yara output line: %v", err)
			} else {
				logger.Errorf("Error parsing yara output line: %v", err)
			}
			if parseErrorCount < yarautils.SampleScanErrorNum {
				parseErrSamples = append(parseErrSamples, err)
			}
			parseErrorCount++
			return
		}
		if malware != nil {
			detectedMalware = append(detectedMalware, *malware)
		}
	}

	// errCheckFunc handles one line of yara error output.
	errCheckFunc := func(line string) {
		logger.Debugf("Error occurred during yara scan: %s", line)
		// Count only the scanner errors to avoid count on status messages.
		if strings.Contains(line, yarautils.ScanError) {
			if scanErrCount < yarautils.SampleScanErrorNum {
				scanErrSamples = append(scanErrSamples, errors.New(line))
			}
			scanErrCount++
		}
	}

	// Define the yara args to run
	args := []string{"-C", s.compiledRuleFile, "-r", "-w", "-m", "-p", strconv.Itoa(runtime.NumCPU())}
	logger.Infof("Running yara...")

	// TODO(ramizpolic): It is unnecessary to pass directories to scan through
	// config since we already have this. We should remove this.
	//
	// Scan the whole filesystem unless specific directories are configured.
	targets := []string{fsPath}
	if len(s.config.DirectoriesToScan) != 0 {
		targets = make([]string, 0, len(s.config.DirectoriesToScan))
		for _, d := range s.config.DirectoriesToScan {
			targets = append(targets, filepath.Join(fsPath, d))
		}
	}

	for _, target := range targets {
		// The full slice expression caps args at its length, so append always
		// copies and each command gets its own argument slice. Every run must
		// include args: without -C yara has no rule file to scan with.
		yaraCommand := exec.CommandContext(ctx, yaraBinaryPath, append(args[:len(args):len(args)], target)...)
		if err := utils.RunCommandAndParseOutputLineByLine(yaraCommand, parserFunc, errCheckFunc); err != nil {
			return nil, fmt.Errorf("failed to run yara command: %w", err)
		}
	}

	// If the stderr lines / stderr lines + stdout lines is greater than the `errThreshold` the error threshold will be reached.
	if yarautils.IsErrorThresholdReached(scanErrCount, outputLines+scanErrCount) {
		return nil, fmt.Errorf(
			"scanner error threshold (%.2f%%) is reached, number of errors=%d, sample of errors: %w",
			yarautils.ErrThreshold*100, // nolint:mnd
			scanErrCount,
			errors.Join(scanErrSamples...),
		)
	}

	// If the parse failures / lines parsed is greater than the `errThreshold` the error threshold will be reached.
	if yarautils.IsErrorThresholdReached(parseErrorCount, outputLines) {
		return nil, fmt.Errorf(
			"output parsing error threshold (%.2f%%) is reached, number of errors=%d, sample of errors: %w",
			yarautils.ErrThreshold*100, // nolint:mnd
			parseErrorCount,
			errors.Join(parseErrSamples...),
		)
	}

	return &types.ScannerResult{
		Source:   userInput,
		Malwares: detectedMalware,
		// TODO: We should calculate scan summary somewhere
		Summary: &types.ScanSummary{},
	}, nil
}

// getCompiledRuleFilePath returns the filesystem path of the compiled yara
// rule file described by cfg.
//
// If both a compiled rule URL and raw rule sources are defined, the compiled
// rule URL wins:
//   - a "file" URL yields the URL's path component as is;
//   - an "http" or "https" URL is downloaded via downloadCompiledRules;
//   - any other scheme is rejected.
//
// Otherwise, if rule sources are defined, the yarac binary is looked up and
// rules.DownloadAndCompile is invoked with cfg.CacheDir (or a directory from
// createCacheDir when CacheDir is empty); the returned path is the compiled
// rule file name inside that cache directory.
//
// An error is returned when neither option is configured or any step fails.
//
// nolint: cyclop
func getCompiledRuleFilePath(cfg config.Config, logger *logrus.Entry) (string, error) {
	// If both compiled rule url and raw rule sources are defined we choose the compiled one.
	if cfg.CompiledRuleURL != "" {
		parsed, err := url.Parse(cfg.CompiledRuleURL)
		if err != nil {
			return "", fmt.Errorf("failed to parse Yara rule URL=%s: %w", cfg.CompiledRuleURL, err)
		}

		switch parsed.Scheme {
		case "file":
			return parsed.Path, nil
		case "http", "https":
			return downloadCompiledRules(cfg.CompiledRuleURL, logger)
		default:
			return "", errors.New("unsupported Yara rules URL")
		}
	}

	if len(cfg.RuleSources) == 0 {
		return "", errors.New("neither compiled rule URL nor rule sources are defined")
	}

	// Report the exact name that was looked up rather than the raw config
	// field, which may differ from what the getter returns.
	yaracBinaryName := cfg.GetYaracBinaryPath()
	yaracBinaryPath, err := exec.LookPath(yaracBinaryName)
	if err != nil {
		return "", fmt.Errorf("failed to lookup executable %s: %w", yaracBinaryName, err)
	}
	logger.Debugf("found yarac binary at: %s", yaracBinaryPath)

	cacheDir := cfg.CacheDir
	if cacheDir == "" {
		cacheDir, err = createCacheDir()
		if err != nil {
			return "", fmt.Errorf("failed to create cache directory: %w", err)
		}
	}

	err = rules.DownloadAndCompile(&ruleserverconfig.Config{
		RuleSources: cfg.RuleSources,
		YaracPath:   yaracBinaryPath,
		CacheDir:    cacheDir,
	}, logger)
	if err != nil {
		return "", fmt.Errorf("failed to download and compile raw rules: %w", err)
	}

	// cacheDir is an OS filesystem path, so join it with filepath, not path.
	return filepath.Join(cacheDir, ruleserverconfig.CompiledRuleFileName), nil
}

// downloadCompiledRules fetches compiledRuleURL with an HTTP GET and writes the
// response body to a newly created temporary file, returning that file's path.
//
// An error is returned when the temporary file cannot be created, the request
// fails, the response status is not 200 OK, or the body cannot be written and
// flushed to disk. On any failure the temporary file is removed again so that
// partial downloads do not accumulate in the temp directory.
//
// NOTE(review): the request carries no context or timeout (hence the noctx
// suppression); consider plumbing a context through if callers need to cancel.
func downloadCompiledRules(compiledRuleURL string, logger *logrus.Entry) (rulePath string, err error) {
	logger.Infof("Downloading Yara rules...")

	out, err := os.CreateTemp(os.TempDir(), "compiledYaraRules")
	if err != nil {
		// out is nil when CreateTemp fails, so it must not be dereferenced here.
		return "", fmt.Errorf("failed to create output file: %w", err)
	}
	defer func() {
		// A failed Close can mean the data never reached the disk, so it is
		// reported instead of being silently dropped.
		if closeErr := out.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("failed to close file=%s: %w", out.Name(), closeErr)
		}
		if err != nil {
			rulePath = ""
			// Best effort: the original error is more useful than a removal failure.
			_ = os.Remove(out.Name())
		}
	}()

	resp, err := http.Get(compiledRuleURL) // nolint:noctx,gosec
	if err != nil {
		return "", fmt.Errorf("failed to get url: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to get url=%s: %s", compiledRuleURL, resp.Status)
	}

	if _, err = io.Copy(out, resp.Body); err != nil {
		return "", fmt.Errorf("failed to write file=%s: %w", out.Name(), err)
	}
	logger.Infof("compiled rule path: %s", out.Name())

	return out.Name(), nil
}

// createCacheDir ensures that a ScannerName subdirectory exists inside the
// directory reported by os.UserCacheDir and returns its path.
//
// An error is returned when the user cache directory cannot be determined or
// the subdirectory (including any missing parents) cannot be created.
func createCacheDir() (string, error) {
	userCache, err := os.UserCacheDir()
	if err != nil {
		return "", fmt.Errorf("unable to determine os cache directory: %w", err)
	}

	scannerCache := path.Join(userCache, ScannerName)

	// MkdirAll is a no-op when the directory already exists.
	if mkdirErr := os.MkdirAll(scannerCache, os.ModePerm); mkdirErr != nil {
		return "", fmt.Errorf("directory creation failed, dir=%s: %w", scannerCache, mkdirErr)
	}

	return scannerCache, nil
}
